"""
Phase 3: Capital Deepening Regressions
=======================================
Does age structure predict capital-per-worker growth, investment, and TFP?
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory containing this script;
# ROOT_DIR is one level above PROJECT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Make ROOT_DIR/multilateral/src importable so `model.PanelGLS` can be loaded.
# Order matters: this path insertion must happen before the import below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory holding the processed panel, and output directory for tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Side effect at import time: ensure the output directory exists so main()
# can write into it without further checks.
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def stars(p):
    """Return the significance marker for p-value *p*.

    Thresholds: '***' if p < 0.01, '**' if p < 0.05, '*' if p < 0.1,
    otherwise ''. A NaN p-value fails every comparison and yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, y_var, x_vars, label, min_obs=50):
    """Fit a PanelGLS regression of *y_var* on *x_vars*; print and return it.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing *y_var*, the regressors, and the 'iso3' and
        'year' columns, which are passed to ``PanelGLS.fit`` alongside
        y and X.
    y_var : str
        Dependent-variable column.
    x_vars : list of str
        Requested regressor columns. Columns absent from *df* are omitted
        with a printed warning, so the fitted model may contain fewer
        regressors than requested.
    label : str
        Model name used in printed output and stored in the result row.
    min_obs : int, default 50
        Minimum number of complete, finite observations required to fit.

    Returns
    -------
    dict or None
        A result row with model metadata plus ``<name>_coef``,
        ``<name>_se`` and ``<name>_p`` for each regressor actually used.
        Returns None when the model is skipped: dependent variable
        missing, no usable regressors, or fewer than *min_obs* rows.
    """
    if y_var not in df.columns:
        print(f"  {label}: dependent variable '{y_var}' not in panel, skipping")
        return None

    actual_x = [v for v in x_vars if v in df.columns]
    missing = [v for v in x_vars if v not in df.columns]
    if missing:
        # Dropping regressors silently would make e.g. an interaction model
        # indistinguishable from its baseline in the output.
        print(f"  {label}: regressors not in panel, omitted: {', '.join(missing)}")
    if not actual_x:
        print(f"  {label}: no regressors available, skipping")
        return None

    cols = [y_var] + actual_x + ['iso3', 'year']
    sub = df[cols].dropna()
    # dropna() keeps +/-inf (e.g. from a log of zero), which would poison
    # the fit; drop rows with any non-finite numeric value as well.
    numeric = sub.select_dtypes(include='number')
    sub = sub[np.isfinite(numeric).all(axis=1)]
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[actual_x].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f}, ρ={gls.rho:.3f})")
    row = {'model': label, 'y_var': y_var, 'n_obs': gls.n_obs,
           'n_countries': gls.n_countries, 'r_squared': gls.r_squared, 'rho': gls.rho}
    # NOTE(review): assumes gls.beta / se / pvalues follow the column order
    # of X with no prepended intercept — confirm against PanelGLS.
    for i, name in enumerate(actual_x):
        s = stars(gls.pvalues[i])
        print(f"    {name:30s} {gls.beta[i]:10.5f} ({gls.se[i]:.5f}) {s}")
        row[f'{name}_coef'] = gls.beta[i]
        row[f'{name}_se'] = gls.se[i]
        row[f'{name}_p'] = gls.pvalues[i]
    return row


def _format_coef(row, name):
    """Return '<coef><stars>' for regressor *name*, or '' if it was not fitted."""
    # A blank cell (rather than a fake 0.0000) signals the regressor was absent.
    if f'{name}_coef' not in row:
        return ''
    return f"{row[f'{name}_coef']:.4f}{stars(row.get(f'{name}_p', 1))}"


def _z_table_row(row):
    """Format one markdown row: model, N, R², then Z_1..Z_3 coefficient cells."""
    z_cells = ' | '.join(_format_coef(row, z) for z in ('Z_1', 'Z_2', 'Z_3'))
    return f"| {row['model']} | {row['n_obs']} | {row['r_squared']:.4f} | {z_cells} |\n"


def _fit_outcome_models(df, y_var, y_label, controls):
    """Fit models A-D for one outcome and return the non-skipped result rows."""
    z_vars = ['Z_1', 'Z_2', 'Z_3']
    # Model A: Z -> outcome
    specs = [(z_vars + controls, f'{y_label}: Z')]
    # Model B: OADR + OADR² (only when the panel has old_dep)
    if 'old_dep' in df.columns:
        specs.append((['old_dep', 'old_dep_sq'] + controls, f'{y_label}: OADR'))
    # Model C: Z plus Z×KAOPEN interaction columns taken from the panel as-is;
    # asks whether the Z effects vary with capital-account openness.
    specs.append((z_vars + ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen'] + controls,
                  f'{y_label}: Z×KAOPEN'))
    # Model D: Z plus NFA split into 'nfa_positive' / 'nfa_negative' columns
    # (no Z×NFA interaction terms are included).
    # NOTE(review): if nfa_positive + nfa_negative reconstructs nfa_gdp_lag,
    # this spec is perfectly collinear with the controls — confirm definitions.
    if 'nfa_positive' in df.columns:
        specs.append((z_vars + ['nfa_positive', 'nfa_negative'] + controls,
                      f'{y_label}: Z+NFA'))

    rows = []
    for x_vars, label in specs:
        r = run_model(df, y_var, x_vars, label)
        if r is not None:
            rows.append(r)
    return rows


def _write_markdown(path, outcomes, results, decomp_results):
    """Write per-outcome Z tables and the decomposition table to *path*."""
    # Explicit UTF-8: the tables contain non-ASCII text (→, ², ₁), which
    # fails under non-UTF-8 platform default encodings.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Table 3: Demographics → Capital Deepening\n\n")

        # Main results: one table per outcome that produced any rows
        for y_var, y_label in outcomes.items():
            subset = [r for r in results if r.get('y_var') == y_var]
            if not subset:
                continue
            f.write(f"\n## {y_label}\n\n")
            f.write("| Model | N | R² | Z₁ | Z₂ | Z₃ |\n")
            f.write("|-------|---|-----|-----|-----|-----|\n")
            f.writelines(_z_table_row(r) for r in subset)
            f.write("\n")

        # Decomposition
        f.write("\n## Savings vs. External Channel\n\n")
        f.write("| Channel | N | R² | Z₁ | Z₂ | Z₃ |\n")
        f.write("|---------|---|-----|-----|-----|-----|\n")
        f.writelines(_z_table_row(r) for r in decomp_results)
        f.write("\n*p<0.1, **p<0.05, ***p<0.01\n")


def main():
    """Run the Phase 3 capital-deepening regressions and write result tables.

    Reads ``deepening_panel.csv`` from DATA. For each outcome present in
    the panel, fits models A-D. It then regresses savings/GDP and current
    account/GDP on Z. Writes ``capital_deepening_results.csv`` and
    ``capital_deepening.md`` to OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 3: CAPITAL DEEPENING REGRESSIONS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Construct OADR² if not present
    if 'old_dep' in df.columns and 'old_dep_sq' not in df.columns:
        df['old_dep_sq'] = df['old_dep'] ** 2

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

    # ── Dependent variables ──
    outcomes = {
        'delta_log_kl': 'Capital Deepening Growth (Δlog K/L)',
        'gross_fixed_investment_gdp': 'Investment/GDP',
        'delta_log_tfp': 'TFP Growth (Δlog TFP)',
        'mpk_proxy': 'MPK Proxy',
        'capital_output_ratio': 'Capital-Output Ratio',
    }

    results = []
    for y_var, y_label in outcomes.items():
        if y_var not in df.columns:
            print(f"\nSkipping {y_label}: variable not in panel")
            continue

        print(f"\n{'='*50}")
        print(f"Outcome: {y_label}")
        print(f"{'='*50}")
        results.extend(_fit_outcome_models(df, y_var, y_label, controls))

    # ── Savings vs. External channel decomposition ──
    print(f"\n{'='*50}")
    print("Savings vs. External Channel Decomposition")
    print(f"{'='*50}")

    decomp_outcomes = {
        'gross_savings_gdp': 'Domestic Savings/GDP',
        'ca_gdp': 'Current Account/GDP',
    }
    decomp_results = []
    for y_var, y_label in decomp_outcomes.items():
        if y_var not in df.columns:
            continue
        r = run_model(df, y_var, ['Z_1', 'Z_2', 'Z_3'] + controls,
                      f'Decomposition: {y_label}')
        if r is not None:
            decomp_results.append(r)

    # ── Save results (main models first, then decomposition rows) ──
    pd.DataFrame(results + decomp_results).to_csv(
        OUT_TABLES / "capital_deepening_results.csv", index=False)
    _write_markdown(OUT_TABLES / "capital_deepening.md",
                    outcomes, results, decomp_results)

    print("\nPhase 3 complete.")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()
